{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8dbb4e8",
   "metadata": {},
   "source": [
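    "# 🧪 Survey Synthetic Dataset Generator — Week 3 Task\n",
    "\n",
    "Generates a synthetic survey-response dataset from a declarative field config (`CFG`). A rule-based sampler, validated with pandera when it is installed, produces the baseline CSV in `data/`; an optional OpenAI-backed generator requests the same schema in JSON mode and falls back to the rule-based sampler whenever the API key is missing or a call fails."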
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d86f629",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os, re, json, time, uuid, math, random\n",
    "from datetime import datetime, timedelta\n",
    "from typing import List, Dict, Any\n",
    "import numpy as np, pandas as pd\n",
    "import pandera.pandas as pa\n",
    "random.seed(7); np.random.seed(7)\n",
    "print(\"✅ Base libraries ready. Pandera available:\", pa is not None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f196ae73",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def extract_strict_json(text: str):\n",
    "    \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n",
    "    if text is None:\n",
    "        raise ValueError(\"Empty model output.\")\n",
    "    \n",
    "    t = text.strip()\n",
    "    \n",
    "    # Strategy 1: Direct JSON parsing\n",
    "    try:\n",
    "        obj = json.loads(t)\n",
    "        if isinstance(obj, list):\n",
    "            return obj\n",
    "        elif isinstance(obj, dict):\n",
    "            for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n",
    "                if key in obj and isinstance(obj[key], list):\n",
    "                    return obj[key]\n",
    "            if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n",
    "                return [obj[k] for k in sorted(obj.keys(), key=int)]\n",
    "    except json.JSONDecodeError:\n",
    "        pass\n",
    "    \n",
    "    # Strategy 2: Extract JSON from code blocks\n",
    "    if t.startswith(\"```\"):\n",
    "        t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n",
    "    \n",
    "    # Strategy 3: Find JSON array in text\n",
    "    start, end = t.find('['), t.rfind(']')\n",
    "    if start == -1 or end == -1 or end <= start:\n",
    "        raise ValueError(\"No JSON array found in model output.\")\n",
    "    \n",
    "    t = t[start:end+1]\n",
    "    \n",
    "    # Strategy 4: Fix common JSON issues\n",
    "    t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t)  # Remove trailing commas\n",
    "    t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t)  # Replace NaN/Infinity\n",
    "    t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\")  # Remove invisible characters\n",
    "    \n",
    "    try:\n",
    "        return json.loads(t)\n",
    "    except json.JSONDecodeError as e:\n",
    "        raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3670fa0d",
   "metadata": {},
   "source": [
    "## 1) Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d16bd03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "CFG = {\n",
    "    \"rows\": 800,\n",
    "    \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n",
    "    \"fields\": [\n",
    "        {\"name\": \"response_id\", \"type\": \"uuid4\"},\n",
    "        {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n",
    "        {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n",
    "        {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n",
    "        {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n",
    "        {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n",
    "        {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n",
    "        {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n",
    "        {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n",
    "        {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n",
    "        {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n",
    "        {\"name\": \"attention_passed\", \"type\": \"bool\"},\n",
    "        {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n",
    "        {\"name\": \"is_detractor\", \"type\": \"bool\"}\n",
    "    ]\n",
    "}\n",
    "print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da1f429",
   "metadata": {},
   "source": [
    "## 2) Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f5fdff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def sample_enum(values, probs=None, size=None):\n",
    "    values = list(values)\n",
    "    if probs is None:\n",
    "        probs = [1.0 / len(values)] * len(values)\n",
    "    return np.random.choice(values, p=probs, size=size)\n",
    "\n",
    "def sample_numeric(field_cfg, size=1):\n",
    "    t = field_cfg[\"type\"]\n",
    "    if t == \"int\":\n",
    "        lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n",
    "        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
    "        if dist == \"uniform\":\n",
    "            return np.random.randint(lo, hi + 1, size=size)\n",
    "        elif dist == \"normal\":\n",
    "            mu = (lo + hi) / 2.0\n",
    "            sigma = (hi - lo) / 6.0\n",
    "            out = np.random.normal(mu, sigma, size=size)\n",
    "            return np.clip(out, lo, hi).astype(int)\n",
    "        else:\n",
    "            return np.random.randint(lo, hi + 1, size=size)\n",
    "    elif t == \"float\":\n",
    "        lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n",
    "        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
    "        if dist == \"uniform\":\n",
    "            return np.random.uniform(lo, hi, size=size)\n",
    "        elif dist == \"normal\":\n",
    "            mu = (lo + hi) / 2.0\n",
    "            sigma = (hi - lo) / 6.0\n",
    "            return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n",
    "        elif dist == \"lognormal\":\n",
    "            mu = math.log(max(1e-3, (lo + hi) / 2.0))\n",
    "            sigma = 0.75\n",
    "            out = np.random.lognormal(mu, sigma, size=size)\n",
    "            return np.clip(out, lo, hi)\n",
    "        else:\n",
    "            return np.random.uniform(lo, hi, size=size)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported numeric type\")\n",
    "\n",
    "def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n",
    "    s = datetime.fromisoformat(start)\n",
    "    e = datetime.fromisoformat(end)\n",
    "    total = int((e - s).total_seconds())\n",
    "    r = np.random.randint(0, total, size=size)\n",
    "    return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f24111a",
   "metadata": {},
   "source": [
    "## 3) Rule-based Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd61330d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def generate_rule_based(CFG: Dict[str, Any], n_rows=None) -> pd.DataFrame:\n",
    "    \"\"\"Generate a synthetic survey DataFrame from CFG without calling an LLM.\n",
    "\n",
    "    Each entry of CFG[\"fields\"] is sampled independently by its \"type\": \"uuid4\"\n",
    "    (random UUID strings), \"int\"/\"float\" (sample_numeric), \"enum\" (sample_enum),\n",
    "    \"datetime\" (sample_datetime over CFG[\"datetime_range\"]) and \"bool\" (90% True);\n",
    "    any other type yields a column of None. Two columns are then re-derived:\n",
    "      * \"nps\" from the mean of the four Likert questions plus N(0, 1.2) noise,\n",
    "        when all four question columns exist;\n",
    "      * \"is_detractor\" with probability 0.25, +0.15 when completion_seconds > 900\n",
    "        and +0.15 when the attention check failed.\n",
    "\n",
    "    Args:\n",
    "        CFG: Config with \"rows\", \"fields\" and optionally \"datetime_range\".\n",
    "        n_rows: Number of rows to generate instead of CFG[\"rows\"] (default None:\n",
    "            use CFG[\"rows\"]). Lets callers resize a batch without copying CFG.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one column per configured field (plus \"nps\" whenever the\n",
    "        Likert columns are present).\n",
    "    \"\"\"\n",
    "    n = CFG[\"rows\"] if n_rows is None else int(n_rows)\n",
    "    dt_cfg = CFG.get(\"datetime_range\", {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"})\n",
    "    data = {}\n",
    "    for f in CFG[\"fields\"]:\n",
    "        name, t = f[\"name\"], f[\"type\"]\n",
    "        if t == \"uuid4\":\n",
    "            data[name] = [str(uuid.uuid4()) for _ in range(n)]\n",
    "        elif t in (\"int\", \"float\"):\n",
    "            data[name] = sample_numeric(f, size=n)\n",
    "        elif t == \"enum\":\n",
    "            data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n",
    "        elif t == \"datetime\":\n",
    "            data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n",
    "        elif t == \"bool\":\n",
    "            data[name] = np.random.rand(n) < 0.9  # 90% True\n",
    "        else:\n",
    "            data[name] = [None] * n\n",
    "    df = pd.DataFrame(data)\n",
    "\n",
    "    # Derive NPS from the Likert questions: map the 1-5 average onto 0-10, add noise, clip.\n",
    "    likert_cols = [\"q_quality\", \"q_value\", \"q_ease\", \"q_support\"]\n",
    "    if set(likert_cols).issubset(df.columns):\n",
    "        likert_avg = df[likert_cols].mean(axis=1)\n",
    "        noise = np.random.normal(0, 1.2, size=n)\n",
    "        df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0 / 4.0) + noise), 0, 10).astype(int)\n",
    "\n",
    "    # Heuristic target: is_detractor more likely when completion is slow and attention failed.\n",
    "    # NOTE(review): this is independent of the \"nps\" column above, so a row can have a high\n",
    "    # NPS and still be flagged as a detractor; confirm that is intended.\n",
    "    if \"is_detractor\" in df.columns:\n",
    "        base = 0.25\n",
    "        comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n",
    "        attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n",
    "        boost = (comp > 900).astype(int) + (~attn).astype(int)\n",
    "        p = np.clip(base + 0.15 * boost, 0.01, 0.95)\n",
    "        df[\"is_detractor\"] = np.random.rand(n) < p\n",
    "\n",
    "    return df\n",
    "\n",
    "# Baseline dataset: rule-based sampling of every configured field, then a preview.\n",
    "df_rule = generate_rule_based(CFG=CFG)\n",
    "df_rule.head(5)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd9eff20",
   "metadata": {},
   "source": [
    "## 4) Validation (Pandera optional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4ef86a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_pandera_schema(CFG):\n",
    "    \"\"\"Build a pandera DataFrameSchema from CFG[\"fields\"], or return None without pandera.\n",
    "\n",
    "    Column dtypes: int -> int, float -> float, bool -> bool; enum, datetime, uuid4\n",
    "    and any other type are checked as object. Value checks mirror the spec:\n",
    "    int/float fields with both \"min\" and \"max\" must lie in that inclusive range,\n",
    "    and enum fields must take one of their configured \"values\".\n",
    "    \"\"\"\n",
    "    if pa is None:\n",
    "        return None\n",
    "    dtypes = {\"int\": int, \"float\": float, \"bool\": bool}\n",
    "    cols = {}\n",
    "    for f in CFG[\"fields\"]:\n",
    "        t, name = f[\"type\"], f[\"name\"]\n",
    "        checks = []\n",
    "        if t in (\"int\", \"float\") and \"min\" in f and \"max\" in f:\n",
    "            checks.append(pa.Check.in_range(f[\"min\"], f[\"max\"]))\n",
    "        elif t == \"enum\" and f.get(\"values\"):\n",
    "            checks.append(pa.Check.isin(list(f[\"values\"])))\n",
    "        cols[name] = pa.Column(dtypes.get(t, object), checks=checks)\n",
    "    return pa.DataFrameSchema(cols)\n",
    "\n",
    "def validate_df(df, CFG):\n",
    "    \"\"\"Validate `df` against build_pandera_schema(CFG) in non-strict mode.\n",
    "\n",
    "    Returns (frame, report). Without pandera the input is returned with a \"basic\"\n",
    "    report. With pandera the validated frame is returned on success; on failure the\n",
    "    error is printed and the ORIGINAL frame is returned (so later cells keep running)\n",
    "    with a report that counts the distinct rows pandera flagged.\n",
    "    \"\"\"\n",
    "    schema = build_pandera_schema(CFG)\n",
    "    if schema is None:\n",
    "        return df, {\"engine\": \"basic\", \"valid_rows\": len(df), \"invalid_rows\": 0}\n",
    "    try:\n",
    "        v = schema.validate(df, lazy=True)\n",
    "        return v, {\"engine\": \"pandera\", \"valid_rows\": len(v), \"invalid_rows\": 0}\n",
    "    except Exception as e:\n",
    "        print(\"Validation error:\", e)\n",
    "        # With lazy=True pandera raises SchemaErrors carrying a `failure_cases` frame whose\n",
    "        # \"index\" column holds the failing row labels (null for column-level failures such\n",
    "        # as a wrong dtype). If the exception has no such frame, no rows are counted.\n",
    "        failure_cases = getattr(e, \"failure_cases\", None)\n",
    "        bad_rows = set()\n",
    "        n_failures = None\n",
    "        if isinstance(failure_cases, pd.DataFrame):\n",
    "            n_failures = len(failure_cases)\n",
    "            if \"index\" in failure_cases.columns:\n",
    "                bad_rows = set(failure_cases[\"index\"].dropna().tolist())\n",
    "        return df, {\n",
    "            \"engine\": \"pandera\",\n",
    "            \"valid_rows\": len(df) - len(bad_rows),\n",
    "            \"invalid_rows\": len(bad_rows),\n",
    "            \"failure_cases\": n_failures,\n",
    "            \"notes\": \"Non-strict mode.\",\n",
    "        }\n",
    "\n",
    "# Validate the baseline frame and preview it; the report records the engine and row counts.\n",
    "validated_rule, report_rule = validate_df(df=df_rule, CFG=CFG)\n",
    "print(report_rule)\n",
    "validated_rule.head(5)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5f1d93a",
   "metadata": {},
   "source": [
    "## 5) Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73626b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from datetime import timezone\n",
    "from pathlib import Path\n",
    "\n",
    "# Write the validated rule-based dataset to ./data with a UTC timestamp in the file name.\n",
    "out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "# datetime.utcnow() is deprecated since Python 3.12; use a timezone-aware UTC now instead.\n",
    "ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "csv_path = out / f\"survey_rule_{ts}.csv\"\n",
    "validated_rule.to_csv(csv_path, index=False)\n",
    "print(\"Saved:\", csv_path.as_posix())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87c89b51",
   "metadata": {},
   "source": [
    "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e94771",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed LLM Generation Functions\n",
    "def create_survey_prompt(CFG, n_rows=50):\n",
    "    \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n",
    "    fields_desc = []\n",
    "    for field in CFG['fields']:\n",
    "        name = field['name']\n",
    "        field_type = field['type']\n",
    "        \n",
    "        if field_type == 'int':\n",
    "            min_val = field.get('min', 0)\n",
    "            max_val = field.get('max', 100)\n",
    "            fields_desc.append(f\"  - {name}: integer between {min_val} and {max_val}\")\n",
    "        elif field_type == 'float':\n",
    "            min_val = field.get('min', 0.0)\n",
    "            max_val = field.get('max', 100.0)\n",
    "            fields_desc.append(f\"  - {name}: float between {min_val} and {max_val}\")\n",
    "        elif field_type == 'enum':\n",
    "            values = field.get('values', [])\n",
    "            fields_desc.append(f\"  - {name}: one of {values}\")\n",
    "        elif field_type == 'bool':\n",
    "            fields_desc.append(f\"  - {name}: boolean (true/false)\")\n",
    "        elif field_type == 'uuid4':\n",
    "            fields_desc.append(f\"  - {name}: UUID string\")\n",
    "        elif field_type == 'datetime':\n",
    "            fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n",
    "            fields_desc.append(f\"  - {name}: datetime string in format {fmt}\")\n",
    "        else:\n",
    "            fields_desc.append(f\"  - {name}: {field_type}\")\n",
    "    \n",
    "    prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n",
    "\n",
    "Schema:\n",
    "{chr(10).join(fields_desc)}\n",
    "\n",
    "CRITICAL REQUIREMENTS:\n",
    "- Return a JSON object with a \"responses\" key containing an array\n",
    "- Each object in the array must have all required fields\n",
    "- Use realistic, diverse values for survey responses\n",
    "- No trailing commas\n",
    "- No comments or explanations\n",
    "\n",
    "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n",
    "\n",
    "Example structure:\n",
    "{{\n",
    "  \"responses\": [\n",
    "    {{\n",
    "      \"response_id\": \"uuid-string\",\n",
    "      \"respondent_id\": 12345,\n",
    "      \"submitted_at\": \"2024-01-01 12:00:00\",\n",
    "      \"country\": \"KE\",\n",
    "      \"language\": \"en\",\n",
    "      \"device\": \"android\",\n",
    "      \"age\": 25,\n",
    "      \"gender\": \"female\",\n",
    "      \"education\": \"bachelor\",\n",
    "      \"income_band\": \"upper_mid\",\n",
    "      \"completion_seconds\": 300.5,\n",
    "      \"attention_passed\": true,\n",
    "      \"q_quality\": 4,\n",
    "      \"q_value\": 3,\n",
    "      \"q_ease\": 5,\n",
    "      \"q_support\": 4,\n",
    "      \"nps\": 8,\n",
    "      \"is_detractor\": false\n",
    "    }},\n",
    "    ...\n",
    "  ]\n",
    "}}\n",
    "\n",
    "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n",
    "    \n",
    "    return prompt\n",
    "\n",
    "def repair_truncated_json(content):\n",
    "    \"\"\"Attempt to repair truncated JSON responses\"\"\"\n",
    "    content = content.strip()\n",
    "    \n",
    "    # If it starts with { but doesn't end with }, try to close it\n",
    "    if content.startswith('{') and not content.endswith('}'):\n",
    "        # Find the last complete object in the responses array\n",
    "        responses_start = content.find('\"responses\": [')\n",
    "        if responses_start != -1:\n",
    "            # Find the last complete object\n",
    "            brace_count = 0\n",
    "            last_complete_pos = -1\n",
    "            in_string = False\n",
    "            escape_next = False\n",
    "            \n",
    "            for i, char in enumerate(content[responses_start:], responses_start):\n",
    "                if escape_next:\n",
    "                    escape_next = False\n",
    "                    continue\n",
    "                    \n",
    "                if char == '\\\\':\n",
    "                    escape_next = True\n",
    "                    continue\n",
    "                    \n",
    "                if char == '\"' and not escape_next:\n",
    "                    in_string = not in_string\n",
    "                    continue\n",
    "                    \n",
    "                if not in_string:\n",
    "                    if char == '{':\n",
    "                        brace_count += 1\n",
    "                    elif char == '}':\n",
    "                        brace_count -= 1\n",
    "                        if brace_count == 0:\n",
    "                            last_complete_pos = i\n",
    "                            break\n",
    "            \n",
    "            if last_complete_pos != -1:\n",
    "                # Truncate at the last complete object and close the JSON\n",
    "                repaired = content[:last_complete_pos + 1] + '\\n  ]\\n}'\n",
    "                print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n",
    "                return repaired\n",
    "    \n",
    "    return content\n",
    "\n",
    "def _rule_based_fallback(CFG, n_rows):\n",
    "    \"\"\"Rule-based rows used whenever the LLM path is unavailable or fails.\"\"\"\n",
    "    tmp = dict(CFG)\n",
    "    tmp['rows'] = n_rows\n",
    "    return generate_rule_based(tmp)\n",
    "\n",
    "\n",
    "def _frame_from_llm_payload(data):\n",
    "    \"\"\"Locate the list of row objects inside parsed LLM JSON and return it as a DataFrame.\n",
    "\n",
    "    Raises ValueError when no row list can be identified.\n",
    "    \"\"\"\n",
    "    if isinstance(data, list):\n",
    "        df = pd.DataFrame(data)\n",
    "        print(f\"📊 Direct array: {len(df)} rows\")\n",
    "        return df\n",
    "    if not isinstance(data, dict):\n",
    "        raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n",
    "    # Known container keys first (\"responses\" is what create_survey_prompt asks for).\n",
    "    for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n",
    "        if isinstance(data.get(key), list):\n",
    "            df = pd.DataFrame(data[key])\n",
    "            print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
    "            return df\n",
    "    # Otherwise use the first non-empty list value...\n",
    "    list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n",
    "    if list_keys:\n",
    "        key = list_keys[0]\n",
    "        df = pd.DataFrame(data[key])\n",
    "        print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
    "        return df\n",
    "    # ...or a dict whose values are all row objects (e.g. {\"0\": {...}, \"1\": {...}}).\n",
    "    if data and all(isinstance(v, dict) for v in data.values()):\n",
    "        df = pd.DataFrame(list(data.values()))\n",
    "        print(f\"📊 Converted dict values: {len(df)} rows\")\n",
    "        return df\n",
    "    raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n",
    "\n",
    "\n",
    "def fixed_llm_generate_batch(CFG, n_rows=50):\n",
    "    \"\"\"Generate one batch of `n_rows` survey responses with the OpenAI Chat API.\n",
    "\n",
    "    Never raises: without OPENAI_API_KEY, or on any API/parsing failure, it prints\n",
    "    the reason and returns rule-based rows instead. The return value alone\n",
    "    therefore does not tell which path produced the data; check the printed log.\n",
    "\n",
    "    The model is called in JSON mode. A reply that was cut off is passed through\n",
    "    repair_truncated_json, and text that json.loads rejects is retried with\n",
    "    extract_strict_json. Rows beyond `n_rows` are dropped; a shortfall is\n",
    "    reported, but the partial batch is still returned.\n",
    "    \"\"\"\n",
    "    if not os.getenv('OPENAI_API_KEY'):\n",
    "        print(\"No OpenAI API key, using rule-based fallback\")\n",
    "        return _rule_based_fallback(CFG, n_rows)\n",
    "\n",
    "    try:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()\n",
    "\n",
    "        prompt = create_survey_prompt(CFG, n_rows)\n",
    "\n",
    "        print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n",
    "\n",
    "        # Budget ~300 tokens per row plus 500 for the JSON wrapper, clamped to 2k-8k.\n",
    "        estimated_tokens = n_rows * 300 + 500\n",
    "        max_tokens = min(max(estimated_tokens, 2000), 8000)\n",
    "\n",
    "        print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model='gpt-4o-mini',\n",
    "            messages=[\n",
    "                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n",
    "                {'role': 'user', 'content': prompt}\n",
    "            ],\n",
    "            temperature=0.3,\n",
    "            max_tokens=max_tokens,\n",
    "            response_format={'type': 'json_object'}\n",
    "        )\n",
    "\n",
    "        choice = response.choices[0]\n",
    "        content = choice.message.content\n",
    "        if not content:\n",
    "            # message.content may be None (e.g. a refusal); fail clearly instead of a TypeError.\n",
    "            raise ValueError(\"Empty model output.\")\n",
    "        print(f\"📝 Raw response length: {len(content)} characters\")\n",
    "\n",
    "        # finish_reason == 'length' is the API's own truncation signal; the suffix\n",
    "        # check also catches replies that stop mid-structure for other reasons.\n",
    "        stripped = content.strip()\n",
    "        if getattr(choice, 'finish_reason', None) == 'length' or not stripped.endswith(('}', ']')):\n",
    "            print(\"⚠️ Response appears truncated, attempting repair...\")\n",
    "            content = repair_truncated_json(content)\n",
    "\n",
    "        try:\n",
    "            data = json.loads(content)\n",
    "            print(f\"🔍 Parsed JSON type: {type(data)}\")\n",
    "            df = _frame_from_llm_payload(data)\n",
    "        except json.JSONDecodeError as e:\n",
    "            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
    "            try:\n",
    "                df = pd.DataFrame(extract_strict_json(content))\n",
    "                print(f\"✅ Recovered with strict parsing: {len(df)} rows\")\n",
    "            except Exception as e2:\n",
    "                print(f\"❌ Strict parsing also failed: {str(e2)}\")\n",
    "                print(f\"🔍 Content sample: {content[:500]}...\")\n",
    "                raise\n",
    "\n",
    "        if df.empty:\n",
    "            raise ValueError(\"No data generated\")\n",
    "        if len(df) > n_rows:\n",
    "            # Over-delivery would make batched callers overshoot their total.\n",
    "            print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}; keeping the first {n_rows}\")\n",
    "            df = df.head(n_rows).reset_index(drop=True)\n",
    "        elif len(df) < n_rows:\n",
    "            print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}\")\n",
    "        else:\n",
    "            print(f\"✅ Successfully generated {len(df)} survey responses\")\n",
    "        return df\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f'❌ LLM error, fallback to rule-based mock: {str(e)}')\n",
    "        return _rule_based_fallback(CFG, n_rows)\n",
    "\n",
    "def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n",
    "    \"\"\"Generate `total_rows` survey responses by calling the LLM in small batches.\n",
    "\n",
    "    The effective batch size is min(batch_size, cap). The cap shrinks as the request\n",
    "    grows (all rows up to 20, 15 up to 50, 10 up to 100, else 8) so that each reply\n",
    "    stays well inside the token budget. A batch that raises or returns no rows is\n",
    "    retried once at half size; if that also fails, the remaining rows come from the\n",
    "    rule-based generator.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with at most `total_rows` rows (empty when total_rows <= 0).\n",
    "    \"\"\"\n",
    "    print(f\"🚀 Generating {total_rows} survey responses with adaptive batching\")\n",
    "\n",
    "    # Adaptive batch sizing based on total rows\n",
    "    if total_rows <= 20:\n",
    "        optimal_batch_size = min(batch_size, total_rows)\n",
    "    elif total_rows <= 50:\n",
    "        optimal_batch_size = min(15, batch_size)\n",
    "    elif total_rows <= 100:\n",
    "        optimal_batch_size = min(10, batch_size)\n",
    "    else:\n",
    "        optimal_batch_size = min(8, batch_size)\n",
    "    # A non-positive batch size would request 0 rows forever.\n",
    "    optimal_batch_size = max(1, optimal_batch_size)\n",
    "\n",
    "    print(f\"📊 Using optimal batch size: {optimal_batch_size}\")\n",
    "\n",
    "    all_dataframes = []\n",
    "    remaining = total_rows\n",
    "\n",
    "    while remaining > 0:\n",
    "        current_batch_size = min(optimal_batch_size, remaining)\n",
    "        print(f\"\\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n",
    "\n",
    "        try:\n",
    "            batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n",
    "            if len(batch_df) == 0:\n",
    "                # An empty batch would leave `remaining` unchanged and loop forever.\n",
    "                raise ValueError(\"Batch returned no rows\")\n",
    "            all_dataframes.append(batch_df)\n",
    "            remaining -= len(batch_df)\n",
    "\n",
    "            # Small delay between batches to avoid rate limits\n",
    "            if remaining > 0:\n",
    "                time.sleep(1.5)\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"❌ Batch failed: {str(e)}\")\n",
    "            print(f\"🔄 Retrying with smaller batch size...\")\n",
    "\n",
    "            smaller_batch = max(1, current_batch_size // 2)\n",
    "            if smaller_batch < current_batch_size:\n",
    "                try:\n",
    "                    print(f\"🔄 Retrying with {smaller_batch} rows...\")\n",
    "                    batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n",
    "                    if len(batch_df) > 0:\n",
    "                        all_dataframes.append(batch_df)\n",
    "                        remaining -= len(batch_df)\n",
    "                        continue\n",
    "                    print(\"❌ Retry also returned no rows\")\n",
    "                except Exception as e2:\n",
    "                    print(f\"❌ Retry also failed: {str(e2)}\")\n",
    "\n",
    "            print(f\"Using rule-based fallback for remaining {remaining} rows\")\n",
    "            # generate_rule_based takes its row count from CFG['rows'], so pass a resized copy.\n",
    "            fallback_cfg = dict(CFG)\n",
    "            fallback_cfg['rows'] = remaining\n",
    "            all_dataframes.append(generate_rule_based(fallback_cfg))\n",
    "            break\n",
    "\n",
    "    if all_dataframes:\n",
    "        # A batch may return more rows than requested; never exceed total_rows.\n",
    "        result = pd.concat(all_dataframes, ignore_index=True).head(total_rows)\n",
    "        print(f\"✅ Generated total: {len(result)} survey responses\")\n",
    "        return result\n",
    "    else:\n",
    "        print(\"❌ No data generated\")\n",
    "        return pd.DataFrame()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1af410e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: one 10-row batch (LLM when OPENAI_API_KEY is set, rule-based fallback otherwise).\n",
    "print(\"🧪 Testing LLM generation...\")\n",
    "test_df = fixed_llm_generate_batch(CFG, n_rows=10)\n",
    "print(\"\\n📊 Generated dataset shape: {}\".format(test_df.shape))\n",
    "print(\"\\n📋 First few rows:\")\n",
    "print(test_df.head(5))\n",
    "print(\"\\n📈 Data types:\")\n",
    "print(test_df.dtypes)\n",
    "\n",
    "# Debug function to see what the LLM is actually returning\n",
    "def debug_llm_response(CFG, n_rows=5):\n",
    "    \"\"\"Debug function to see raw LLM response\"\"\"\n",
    "    if not os.getenv('OPENAI_API_KEY'):\n",
    "        print(\"No OpenAI API key available for debugging\")\n",
    "        return\n",
    "    \n",
    "    try:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()\n",
    "        \n",
    "        prompt = create_survey_prompt(CFG, n_rows)\n",
    "        \n",
    "        print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n",
    "        print(f\"📝 Prompt length: {len(prompt)} characters\")\n",
    "        \n",
    "        response = client.chat.completions.create(\n",
    "            model='gpt-4o-mini',\n",
    "            messages=[\n",
    "                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n",
    "                {'role': 'user', 'content': prompt}\n",
    "            ],\n",
    "            temperature=0.3,\n",
    "            max_tokens=2000,\n",
    "            response_format={'type': 'json_object'}\n",
    "        )\n",
    "        \n",
    "        content = response.choices[0].message.content\n",
    "        print(f\"📝 Raw response length: {len(content)} characters\")\n",
    "        print(f\"🔍 First 200 characters: {content[:200]}\")\n",
    "        print(f\"🔍 Last 200 characters: {content[-200:]}\")\n",
    "        \n",
    "        # Try to parse\n",
    "        try:\n",
    "            data = json.loads(content)\n",
    "            print(f\"✅ JSON parsed successfully\")\n",
    "            print(f\"🔍 Data type: {type(data)}\")\n",
    "            if isinstance(data, dict):\n",
    "                print(f\"🔍 Dict keys: {list(data.keys())}\")\n",
    "            elif isinstance(data, list):\n",
    "                print(f\"🔍 List length: {len(data)}\")\n",
    "        except Exception as e:\n",
    "            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
    "            \n",
    "    except Exception as e:\n",
    "        print(f\"❌ Debug failed: {str(e)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c90739",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the fixed implementation\n",
    "print(\"🧪 Testing the fixed LLM generation...\")\n",
    "\n",
    "# Test with small dataset\n",
    "test_df = fixed_llm_generate_batch(CFG, 5)\n",
    "print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
    "print(f\"\\n📋 First few rows:\")\n",
    "print(test_df.head())\n",
    "print(f\"\\n📈 Data types:\")\n",
    "print(test_df.dtypes)\n",
    "\n",
    "# fixed_llm_generate_batch falls back to rule-based rows when no API key is set (or the\n",
    "# LLM call fails), so a non-empty frame alone does not prove the LLM produced it.\n",
    "llm_configured = bool(os.getenv(\"OPENAI_API_KEY\"))\n",
    "if test_df.empty:\n",
    "    print(f\"\\n❌ Still having issues with LLM generation\")\n",
    "elif not llm_configured:\n",
    "    print(\"\\nℹ️ OPENAI_API_KEY is not set, so these rows came from the rule-based fallback, not the LLM.\")\n",
    "else:\n",
    "    print(f\"\\n✅ SUCCESS! LLM generation is now working!\")\n",
    "    print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n",
    "    print(\"   (If the log above shows an LLM error, these rows came from the rule-based fallback.)\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd83b842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test larger dataset generation in batches, then save it.\n",
    "print(\"🚀 Testing larger dataset generation...\")\n",
    "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n",
    "if not large_df.empty:\n",
    "    print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n",
    "    print(f\"\\n📈 Summary statistics:\")\n",
    "    print(large_df.describe())\n",
    "    \n",
    "    # Save the results with a UTC timestamp (datetime.utcnow() is deprecated since Python 3.12).\n",
    "    from datetime import timezone\n",
    "    from pathlib import Path\n",
    "    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "    ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "    csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n",
    "    large_df.to_csv(csv_path, index=False)\n",
    "    print(f\"💾 Saved: {csv_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6029d3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_json_schema(CFG):\n",
    "    \"\"\"Translate CFG['fields'] into a JSON Schema for an array of row objects.\n",
    "\n",
    "    Every field is listed as required. 'int' -> integer, 'float' -> number,\n",
    "    'bool' -> boolean, 'enum' -> string limited to the field's 'values';\n",
    "    every other type (uuid4, datetime, unknown) becomes a plain string.\n",
    "    \"\"\"\n",
    "    simple_types = {'int': 'integer', 'float': 'number', 'bool': 'boolean'}\n",
    "    properties, required = {}, []\n",
    "    for field in CFG['fields']:\n",
    "        field_name, field_type = field['name'], field['type']\n",
    "        required.append(field_name)\n",
    "        if field_type == 'enum':\n",
    "            spec = {'type': 'string', 'enum': field['values']}\n",
    "        else:\n",
    "            spec = {'type': simple_types.get(field_type, 'string')}\n",
    "        properties[field_name] = spec\n",
    "    return {\n",
    "        'type': 'array',\n",
    "        'items': {'type': 'object', 'properties': properties, 'required': required},\n",
    "    }\n",
    "\n",
    "# Fixed instruction text placed under 'preamble' in every render_prompt payload.\n",
    "# It says 'exactly N objects' generically; the concrete count is sent next to it as 'n_rows'.\n",
    "PROMPT_PREAMBLE = (\n",
    "    \"You are a data generator. Return ONLY JSON. \"\n",
    "    \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n",
    "    \"No prose, no code fences, no trailing commas.\"\n",
    ")\n",
    "\n",
    "def render_prompt(CFG, n_rows=100):\n",
    "    \"\"\"Assemble the prompt payload (a plain dict, JSON-encoded by the caller) for one batch.\n",
    "\n",
    "    Per field only the keys the model needs are forwarded: 'name'/'type' when present,\n",
    "    'min'/'max' only when both exist, plus 'values' and 'fmt' when set.\n",
    "    \"\"\"\n",
    "    trimmed_fields = []\n",
    "    for field in CFG['fields']:\n",
    "        entry = {key: field[key] for key in ('name', 'type') if key in field}\n",
    "        if 'min' in field and 'max' in field:\n",
    "            entry['min'] = field['min']\n",
    "            entry['max'] = field['max']\n",
    "        for optional_key in ('values', 'fmt'):\n",
    "            if optional_key in field:\n",
    "                entry[optional_key] = field[optional_key]\n",
    "        trimmed_fields.append(entry)\n",
    "    instruction = f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n",
    "    return {\n",
    "        'preamble': PROMPT_PREAMBLE,\n",
    "        'n_rows': n_rows,\n",
    "        'schema': build_json_schema(CFG),\n",
    "        'constraints': {'fields': trimmed_fields},\n",
    "        'instruction': instruction,\n",
    "    }\n",
    "\n",
    "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n",
    "    \"\"\"Convert a raw model reply into a DataFrame of generated rows.\n",
    "\n",
    "    Fast path: the reply parses as a JSON object whose 'rows' value is a list.\n",
    "    Anything else (invalid JSON, another shape, or a DataFrame build error) is\n",
    "    delegated to extract_strict_json, whose errors propagate to the caller.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        payload = json.loads(raw)\n",
    "        rows = payload.get('rows') if isinstance(payload, dict) else None\n",
    "        if isinstance(rows, list):\n",
    "            return pd.DataFrame(rows)\n",
    "    except Exception:\n",
    "        pass  # fall through to the more lenient extractor\n",
    "    return pd.DataFrame(extract_strict_json(raw))\n",
    "\n",
    "# Enable the OpenAI path only when OPENAI_API_KEY is set to a non-empty value;\n",
    "# evaluated once when this cell runs, and read by llm_generate_batch on every call.\n",
    "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n",
    "print('LLM available:', USE_LLM)\n",
    "\n",
    "def _chat_json(client, content, model, temperature, max_tokens):\n",
    "    \"\"\"Send one JSON-mode chat completion request and return the raw reply text.\"\"\"\n",
    "    resp = client.chat.completions.create(\n",
    "        model=model,\n",
    "        response_format={'type': 'json_object'},\n",
    "        messages=[\n",
    "            {'role':'system','content':'You output strict JSON only.'},\n",
    "            {'role':'user','content': content}\n",
    "        ],\n",
    "        temperature=temperature,\n",
    "        max_tokens=max_tokens,\n",
    "    )\n",
    "    return resp.choices[0].message.content\n",
    "\n",
    "def _fit_to_size(df, n_rows):\n",
    "    \"\"\"Drop rows beyond n_rows; warn (without padding) when fewer rows came back.\"\"\"\n",
    "    if len(df) > n_rows:\n",
    "        return df.iloc[:n_rows].reset_index(drop=True)\n",
    "    if len(df) < n_rows:\n",
    "        print(f'⚠️ LLM returned {len(df)} rows, expected {n_rows}')\n",
    "    return df\n",
    "\n",
    "def llm_generate_batch(CFG, n_rows=50, model='gpt-4o-mini', temperature=0.2, max_tokens=8192):\n",
    "    \"\"\"Generate one batch of n_rows rows, via the OpenAI API when USE_LLM is set.\n",
    "\n",
    "    The first reply is parsed with parse_llm_json_to_df; if parsing fails, one retry\n",
    "    is made with a stricter prompt. Any error on the LLM path (or USE_LLM being False)\n",
    "    falls back to generate_rule_based on a shallow copy of CFG with 'rows' = n_rows.\n",
    "\n",
    "    Args:\n",
    "        CFG: generator config; its 'fields' drive the prompt via render_prompt.\n",
    "        n_rows: rows requested for this batch.\n",
    "        model, temperature, max_tokens: chat-completion settings; the defaults are\n",
    "            the values this function previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        On the LLM path, a DataFrame with at most n_rows rows (surplus rows dropped);\n",
    "        otherwise whatever generate_rule_based returns.\n",
    "    \"\"\"\n",
    "    if USE_LLM:\n",
    "        try:\n",
    "            from openai import OpenAI\n",
    "            client = OpenAI()\n",
    "            prompt = json.dumps(render_prompt(CFG, n_rows))\n",
    "            raw = _chat_json(client, prompt, model, temperature, max_tokens)\n",
    "            try:\n",
    "                df = parse_llm_json_to_df(raw)\n",
    "            except Exception:\n",
    "                # Retry once with an explicit shape, spelling out the real row count\n",
    "                # instead of a placeholder 'N'.\n",
    "                stricter = (\n",
    "                    prompt\n",
    "                    + \"\\nReturn ONLY a JSON object structured as: \"\n",
    "                    + f\"{{\\\"rows\\\": [ ... exactly {n_rows} objects ... ]}}. \"\n",
    "                    + \"No prose, no explanations.\"\n",
    "                )\n",
    "                raw2 = _chat_json(client, stricter, model, temperature, max_tokens)\n",
    "                df = parse_llm_json_to_df(raw2)\n",
    "            # The model does not reliably honour \"exactly N\"; keep batches from overshooting.\n",
    "            return _fit_to_size(df, n_rows)\n",
    "        except Exception as e:\n",
    "            print('LLM error, fallback to rule-based mock:', e)\n",
    "    tmp = dict(CFG); tmp['rows'] = n_rows\n",
    "    return generate_rule_based(tmp)\n",
    "\n",
    "def generate_llm(CFG, total_rows=200, batch_size=50, pause_s=0.2):\n",
    "    \"\"\"Build a dataset of `total_rows` rows from successive llm_generate_batch calls.\n",
    "\n",
    "    Args:\n",
    "        CFG: generator config, forwarded unchanged to llm_generate_batch.\n",
    "        total_rows: rows requested in total; a value <= 0 returns an empty DataFrame.\n",
    "        batch_size: maximum rows requested per call; must be positive.\n",
    "        pause_s: seconds slept between consecutive calls (not after the last one).\n",
    "\n",
    "    Returns:\n",
    "        All batches concatenated with a fresh 0..n-1 index. The row count is not\n",
    "        re-checked here, so it depends on what each batch actually returned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if batch_size <= 0 (the loop would otherwise never terminate).\n",
    "    \"\"\"\n",
    "    if batch_size <= 0:\n",
    "        raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n",
    "    if total_rows <= 0:\n",
    "        return pd.DataFrame()  # pd.concat([]) would raise 'No objects to concatenate'\n",
    "    frames = []\n",
    "    remaining = total_rows\n",
    "    while remaining > 0:\n",
    "        if frames:\n",
    "            time.sleep(pause_s)  # light pacing between API calls\n",
    "        size = min(batch_size, remaining)\n",
    "        frames.append(llm_generate_batch(CFG, n_rows=size))\n",
    "        remaining -= size\n",
    "    return pd.concat(frames, ignore_index=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e759087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the generate_llm pipeline defined above: 100 rows requested in two batches of 50.\n",
    "# Each batch falls back to the rule-based mock when USE_LLM is False or the LLM call fails.\n",
    "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n",
    "df_llm.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d4908ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the improved LLM generation with adaptive batching\n",
    "print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n",
    "\n",
    "# Test with smaller dataset first\n",
    "print(\"\\n📦 Testing small batch (10 rows)...\")\n",
    "small_df = fixed_llm_generate_batch(CFG, 10)\n",
    "print(f\"✅ Small batch result: {len(small_df)} rows\")\n",
    "\n",
    "# Test with medium dataset using adaptive batching\n",
    "print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n",
    "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n",
    "print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n",
    "\n",
    "if not medium_df.empty:\n",
    "    print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n",
    "    print(\"\\n📋 First few rows:\")\n",
    "    print(medium_df.head())\n",
    "\n",
    "    # Save the results under data/ with a UTC timestamp in the filename.\n",
    "    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC datetime instead.\n",
    "    from pathlib import Path\n",
    "    from datetime import timezone\n",
    "    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "    ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "    csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n",
    "    medium_df.to_csv(csv_path, index=False)\n",
    "    print(f\"💾 Saved: {csv_path}\")\n",
    "else:\n",
    "    print(\"❌ Medium dataset generation failed\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
